Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds @zen decorator for task functions #310

Merged
merged 35 commits into from
Oct 19, 2022
Merged

Adds @zen decorator for task functions #310

merged 35 commits into from
Oct 19, 2022

Conversation

rsokl
Copy link
Contributor

@rsokl rsokl commented Sep 28, 2022

This PR introduces the zen decorator, which changes the interface of an arbitrary (Hydra-agnostic) function so that it can accept a Hydra config as its input. The decorator inspects a inner-function's signature and extracts (+resolves and instantiates) the appropriate fields from an input config to call said function:

from hydra_zen import make_config, zen

def func(a, b, c):
    return a + b + c
>>> zen_wrapped_func = zen(func)
>>> Cfg = make_config(a=1, b=builds(int, 2), c="${a}", unused=100)
# extracts Cfg.a/b/c, resolves interpolated fields, and instantiates targeted configs, and calls `func(a, b, c)`
>>> zen_wrapped_func(Cfg)  #  a=1 + b=2 + c=1 -> 4
4

Using @zen to improve task functions

@zen is designed to help decouple one's task function from the Hydra framework. By doing so, it improves the task function's legibility, versatility, and testability.

Given:

from hydra_zen import make_config, builds, just, instantiate

class Foo:
    def __init__(self, x: str) -> None:
        pass

Cfg = make_config(seed=1, foo=builds(Foo, x="bar"), unused=[1, 2])

One typically writes task functions as so:

def old_task_fn(cfg):
    seed = cfg.seed  # manually accessing and instantiating attributes
    foo = instantiate(cfg.foo)
    print(seed, foo)
>>> old_task_fn(Cfg)
1 <__main__.Foo object at 0x00000262295F0D00>

There are several issues with this:

  • The input – cfg – is relatively opaque to users and type-checkers alike.
  • The process of manually accessing and instantiating attributes is tedious and produces boilerplate code (it isn't so bad for this toy example, but it gets bad quickly in practice)
  • old_task_fn is tightly coupled to the Hydra framework – you must pass it a config to run it

@zen strives to rectify all of these shortcomings. Let's use it to refactor our task function:

from hydra_zen import zen

@zen
def new_task_fn(seed: int, foo: Foo):
    print(seed, foo)
>>> new_task_fn(Cfg)
1 <__main__.Foo object at 0x00000262295F0D00>

Here, @zen makes new_task_fn explicit, legible, and boilerplate free. It works by inspecting the signature of new_task_fn and extracting and instantiating the corresponding parameters from our config12. Given the explicit signature (with optional annotations), users and IDEs can easily understand the context of the task function's body.

Furthermore, one can run the underlying task function, via .func, independently of a Hydra app:

>>> new_task_fn.func(seed=10, foo=Foo("hi"))
10 <__main__.Foo object at 0x00000262295F4790>

Given this accessibility, and because our task-function is now free of Hydra-specific boilerplate code, we can easily use/test our task function outside of the context of our Hydra app.

zen makes it trivial to take any 3rd party function and transform it into a Hydra-compatible task function:

Using @zen instead of @hydra.main

The object returned by zen provides a convenience method -- Zen.hydra_main -- so that users need not double-wrap with @hydra.main to create a CLI:

# example.py
from hydra.core.config_store import ConfigStore

from hydra_zen import builds, zen

def f(x: int, y: int):
    print(x, y)

cs = ConfigStore.instance()
cs.store(name="my_app", node=builds(f, populate_full_signature=True))


if __name__ == "__main__":
    zen(f).hydra_main(config_name="my_app", config_path=None)
$ python example.py x=1 y=2
1 2

Additional Bells & Whistles

Validation

A zen-wrapped function can validate configs without calling the function itself. This makes it easy to test compatibility between your task functions and configs (e.g., as part of your CI/CD process)

>>> def f(x: int): ...
>>> zen_f = zen(f)
>>> zen_f.validate({"x": 1})  # OK
>>> zen_f.validate({"y": 1})  # Missing x
---------------------------------------------------------------------------
HydraZenValidationError: `cfg` is missing the following fields: x

Customizing the Wrapper Behavior

One can subclass hydra_zen.wrapper.Zen and pass it to @zen to modify the wrapped behavior.
In the following example we add the ability to log the config (as a yaml) upon each call of a zen-wrapped function.

from hydra_zen.wrapper import Zen
from hydra_zen import to_yaml, zen

class MyZen(Zen):
    def __call__(self, __cfg):
        print(f"Logged:\n{to_yaml(__cfg)}")
        return super().__call__(__cfg)
>>> @zen(ZenWrapper=MyZen)
... def fn(x: int, y: str): return (x, y)

>>> fn(dict(x=1, y="hi"))
Logged:
x: 1
'y': hi

(1, 'hi')

Adding a Pre-Call Step

Recall that @zen will automatically instantiate a sub-config prior to passing it to the decorated function. If that instantiated object relies on random behavior, it can be useful to be able to set a seed prior to the instantiation process. We can do this via @zen(pre_call=...):

Without pre-call:

import random
from hydra_zen import builds, zen


@zen
def f(rand_val: int):
    return rand_val


Cfg = dict(rand_val=builds(random.randint, 0, 10))
>>> [f(Cfg) for _ in range(10)]
[3, 8, 2, 4, 2, 1, 9, 4, 8, 9]

With pre-call:

import random
from hydra_zen import builds, zen

# Note that we use zen to also wrap the pre-call function: to extract `seed` from the config
@zen(pre_call=zen(lambda seed: random.seed(seed)))  
def f(rand_val: int):
    return rand_val


Cfg = dict(
        rand_val=builds(random.randint, 0, 10),
        seed=0,
)
>>> [f(Cfg) for _ in range(10)]
[6, 6, 6, 6, 6, 6, 6, 6, 6, 6]

Validation propagates through zen-wrapped pre-call functions:

>>> def f(x: int): ...
>>> zen_f = zen(f, pre_call=zen(lambda seed: None))
>>> zen_f.validate({"x": 1, "seed": 10})  # OK
>>> zen_f.validate({"x": 1})  # Missing seed as required by pre-call 
---------------------------------------------------------------------------
HydraZenValidationError: `cfg` is missing the following fields: seed

Passing Through The Config

Some task functions require complete access to the full config to gain access to sub-configs. One can specify the field named zen_config3 in their task function's signature to signal zen that it should pass the full config to that parameter .

from hydra_zen import zen

def f(x: int, zen_cfg):
    return x, zen_cfg
>>> zen(f)(dict(x=1, y="${x}"))
(1, {'x': 1, 'y': 1})

Footnotes

  1. @zen only performs instantiation on extracted fields as-needed. Thus it avoids accessing/instantiating parts of a larger config that are not necessary for the given task function.

  2. Interpolated fields are resolved by calls mediated through zen

  3. You can change this specialized name by subclassing Zen

@rsokl rsokl added the enhancement New feature or request label Sep 28, 2022
@rsokl rsokl added this to the hydra-zen 0.9.0 milestone Sep 28, 2022
@Jasha10
Copy link
Contributor

Jasha10 commented Sep 28, 2022

I always like the idea of decoupling the application logic from the configuration framework (à la Bob Martin's dictum that frameworks should be kept at arms length). This @zen/@hydra_main feature makes decoupling easier.

With complex applications, it might be difficult to migrate from old_task_fn to new_task_fn all at once (e.g. if old_task_fn has many subconfigs or sub-subconfigs that are manually instantiated, possibly involving custom logic).

One possible pattern to make gradual migration easier would be for @zen to give special treatment to a zen_cfg keyword argument if it is present in the signature of new_task_fn, using zen_cfg to pass the Cfg object through unmodified. This would ease incremental migration:

from hydra_zen import builds, instantiate, make_config, zen

class Foo:
    def __init__(self, x: str) -> None:
        pass

Cfg = make_config(seed=1, foo=builds(Foo, x="bar"), unused=[1, 2])

# gradually migrate `old_task_fn` to `new_task_fn`:

def old_task_fn(cfg):
    seed = cfg.seed
    foo = instantiate(cfg.foo)
    print(seed, foo)

@zen
def new_task_fn_v0(zen_cfg):  # special `zen_cfg` keyword passes unmodified config
    cfg = zen_cfg
    seed = cfg.seed
    foo = instantiate(cfg.foo)
    print(seed, foo)

@zen
def new_task_fn_v1(seed, zen_cfg):
    cfg = zen_cfg
    foo = instantiate(cfg.foo)
    print(seed, foo)

@zen
def new_task_fn_v2(seed, foo, zen_cfg):
    cfg = zen_cfg
    print(seed, foo)

@zen
def new_task_fn_final(seed, foo):
    print(seed, foo)


# all of the below are equivalent:
old_task_fn(Cfg)
new_task_fn_v0(Cfg)
new_task_fn_v1(Cfg)
new_task_fn_v2(Cfg)
new_task_fn_final(Cfg)
$ # Below is the diff I used to accomplish this special treatment of the `zen_cfg` keyword argument
$ git diff
diff --git a/src/hydra_zen/_zen.py b/src/hydra_zen/_zen.py
index 5ba7e99..ecff97f 100644
--- a/src/hydra_zen/_zen.py
+++ b/src/hydra_zen/_zen.py
@@ -169,15 +169,20 @@ class Zen(Generic[P, T1]):
                 else getattr(cfg, name)
             )
             for name, param in self.parameters.items()
+            if name != "zen_cfg"
             if param.kind not in SKIPPED_PARAM_KINDS
         }

+        kwargs_final = {
+            name: instantiate(val) if is_instantiable(val) else val
+            for name, val in cfg_kwargs.items()
+        }
+        if "zen_cfg" in self.parameters:
+            kwargs_final["zen_cfg"] = cfg
+
         out = self.func(
             *(instantiate(x) if is_instantiable(x) else x for x in args_),
-            **{
-                name: instantiate(val) if is_instantiable(val) else val
-                for name, val in cfg_kwargs.items()
-            },
+            **kwargs_final,
         )  # type: ignore

         return out

@rsokl
Copy link
Contributor Author

rsokl commented Sep 28, 2022

Thanks for this, @Jasha10 ! I had been thinking of including an "escape hatch" like the one you sketched out with zen_cfg. And I agree with your reasoning about accommodating task functions that leverage complex sub-configs. I think that this is a good idea!

The one thing that I'll need to take care if is to make clear to the user that Hydra does not perform singleton instantiation. E.g. they might be tempted to believe that the following will hold:

@zen
def func(model, zen_cfg):
    assert model is instantiate(zen_cfg.model)  # <- this will fail!

(there are other ways, e.g. via interpolation, where this confusion could manifest, but those are independent of @zen)

I always like the idea of decoupling the application logic from the configuration framework (à la Bob Martin's dictum that frameworks should be kept at arms length).

This is sage advice! I have developed an intuition for this, but I never could have put it so succinctly. I will be sure to read this blog post 😄

.github/workflows/nightly.yml Outdated Show resolved Hide resolved
@rsokl rsokl marked this pull request as ready for review October 19, 2022 17:10
@rsokl rsokl merged commit 8e8a645 into main Oct 19, 2022
@rsokl rsokl deleted the add-zen branch October 19, 2022 17:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants